// SPDX-License-Identifier: GPL-2.0
#include "vmlinux.h"
#include <bpf/bpf_core_read.h>
#include <bpf/bpf_helpers.h>
#include <bpf/bpf_tracing.h>
#include "mountsnoop.h"
#include "compat.bpf.h"

/*
 * Capacity of the args map below: how many threads can have a mount/umount
 * call recorded (entered but not yet exited) at the same time.
 */
#define MAX_ENTRIES 10240

/*
 * Process (tgid) filter: when non-zero, only calls made by this process are
 * recorded; 0 traces every process.
 * NOTE(review): declared const volatile, presumably so the userspace loader
 * can set it (read-only data) before the program is loaded — confirm against
 * the loader.
 */
const volatile pid_t target_pid = 0;

/*
 * Per-thread stash of syscall arguments: probe_entry() stores a struct arg
 * under the calling thread's id, and probe_exit() reads it back and then
 * removes it.
 */
struct {
    __uint(type, BPF_MAP_TYPE_HASH);
    __type(key, __u32);         /* thread id (low half of pid_tgid) */
    __type(value, struct arg);  /* arguments captured at syscall entry */
    __uint(max_entries, MAX_ENTRIES);
} args SEC(".maps");

/*
 * Shared syscall-entry handler for mount/umount.
 *
 * Unless the caller is filtered out by target_pid, records the entry
 * timestamp together with the raw (user-space) argument pointers and flags
 * in the args map, keyed by the current thread id. The strings themselves
 * are not read here; probe_exit() copies them when the call returns.
 *
 * Always returns 0.
 */
static __always_inline int probe_entry(const char *src, const char *dest,
                                       const char *fs, __u64 flags,
                                       const char *data, enum op op)
{
    __u64 id = bpf_get_current_pid_tgid();
    __u32 tgid = id >> 32;
    __u32 thread_id = (__u32)id;

    /* A non-zero target_pid restricts tracing to that one process. */
    if (target_pid != 0 && tgid != target_pid)
        return 0;

    /* Zero-initialize so no uninitialized padding is handed to the map. */
    struct arg pending = {};

    pending.op = op;
    pending.flags = flags;
    pending.src = src;
    pending.dest = dest;
    pending.fs = fs;
    pending.data = data;
    pending.ts = bpf_ktime_get_ns();

    bpf_map_update_elem(&args, &thread_id, &pending, BPF_ANY);
    return 0;
}

static __always_inline int probe_exit(void *ctx, int ret)
{
    __u64 pid_tgid = bpf_get_current_pid_tgid();
    __u32 pid = pid_tgid >> 32;
    __u32 tid = (__u32)pid_tgid;
    struct task_struct *task;
    struct event *eventp;
    struct arg *argp;

    argp = bpf_map_lookup_elem(&args, &tid);
    if (!argp)
        return 0;

    eventp = reserve_buf(sizeof(*eventp));
    if (!eventp)
        goto cleanup;

    task = (struct task_struct *)bpf_get_current_task();
    eventp->delta = bpf_ktime_get_ns() - argp->ts;
    eventp->flags = argp->flags;
    eventp->pid = pid;
    eventp->tid = tid;
    eventp->mnt_ns = BPF_CORE_READ(task, nsproxy, mnt_ns, ns.inum);
    eventp->ret = ret;
    eventp->op = argp->op;
    bpf_get_current_comm(&eventp->comm, sizeof(eventp->comm));
    if (argp->src)
        bpf_probe_read_user_str(eventp->src, sizeof(eventp->src), argp->src);
    else
        eventp->src[0] = '\0';
    if (argp->dest)
        bpf_probe_read_user_str(eventp->dest, sizeof(eventp->dest), argp->dest);
    else
        eventp->dest[0] = '\0';
    if (argp->fs)
        bpf_probe_read_user_str(eventp->fs, sizeof(eventp->fs), argp->fs);
    else
        eventp->fs[0] = '\0';
    if (argp->data)
        bpf_probe_read_user_str(eventp->data, sizeof(eventp->data), argp->data);
    else
        eventp->data[0] = '\0';

    submit_buf(ctx, eventp, sizeof(*eventp));

cleanup:
    bpf_map_delete_elem(&args, &tid);
    return 0;
}

/*
 * Entry of mount(source, target, filesystemtype, mountflags, data):
 * forward the raw arguments, in syscall order, to probe_entry().
 */
SEC("tracepoint/syscalls/sys_enter_mount")
int mount_entry(struct trace_event_raw_sys_enter *ctx)
{
    return probe_entry((const char *)ctx->args[0],  /* source */
                       (const char *)ctx->args[1],  /* target */
                       (const char *)ctx->args[2],  /* filesystem type */
                       (__u64)ctx->args[3],         /* mount flags */
                       (const char *)ctx->args[4],  /* data */
                       MOUNT);
}

/* Exit of mount(): hand the syscall's return value to probe_exit(). */
SEC("tracepoint/syscalls/sys_exit_mount")
int mount_exit(struct trace_event_raw_sys_exit *ctx)
{
    int rc = (int)ctx->ret;

    return probe_exit(ctx, rc);
}

/*
 * Entry of umount2(target, flags). Only the target path and the flags exist
 * for this call. Source, filesystem type and data are recorded as NULL, so
 * probe_exit() reports them as empty strings.
 */
SEC("tracepoint/syscalls/sys_enter_umount")
int umount_entry(struct trace_event_raw_sys_enter *ctx)
{
    const char *dest = (const char *)ctx->args[0];
    /*
     * umount2()'s flags parameter is a 32-bit int, while args[] holds a full
     * long. Keep only the low 32 bits, the part the kernel interprets, so
     * sign extension or stale upper register bits never appear as bogus
     * high flag bits in the event.
     */
    __u64 flags = (__u32)ctx->args[1];

    return probe_entry(NULL, dest, NULL, flags, NULL, UMOUNT);
}

/* Exit of umount2(): hand the syscall's return value to probe_exit(). */
SEC("tracepoint/syscalls/sys_exit_umount")
int umount_exit(struct trace_event_raw_sys_exit *ctx)
{
    int rc = (int)ctx->ret;

    return probe_exit(ctx, rc);
}

/*
 * License of this BPF object, emitted into the "license" section.
 * NOTE(review): the kernel presumably requires a GPL-compatible value here
 * for GPL-only helpers used above (e.g. bpf_probe_read_user_str,
 * bpf_get_current_task) — keep it GPL-compatible.
 */
char LICENSE[] SEC("license") = "GPL";